# -*-Python-*-
# 64-way data-parallelism, 8-way model-parallelism on tpu 16x16

# Mesh with two named dimensions: "batch" of size 64 (the data-parallel axis)
# and "model" of size 8 (the model-parallel axis); 64 * 8 = 512 positions total.
utils.run.mesh_shape = "batch:64,model:8"
# Comma-separated "name:mesh_dimension" pairs. "batch" is assigned to the
# "batch" mesh dimension; "d_ff", "heads" and "vocab" are all assigned to the
# "model" mesh dimension. Every right-hand name must match a dimension declared
# in mesh_shape above.
# NOTE(review): presumably the left-hand names are tensor dimension names that
# get split across the given mesh dimension -- confirm against utils.run.
utils.run.layout_rules = "batch:batch,d_ff:model,heads:model,vocab:model"
